# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
# pylint: disable=too-many-locals
# pylint: disable=import-error
# pylint: disable=unused-argument
from datetime import datetime, timedelta, timezone

import azure.cli.command_modules.backup.custom_help as helper
import azure.cli.command_modules.backup.custom_common as common

from azure.mgmt.recoveryservicesbackup.activestamp.models import ProtectedItemResource, \
    RestoreRequestResource, BackupRequestResource, RestoreFileSpecs, \
    AzureFileShareBackupRequest, AzureFileshareProtectedItem, AzureFileShareRestoreRequest, \
    TargetAFSRestoreInfo, ProtectionState, ProtectionContainerResource, AzureStorageContainer

from azure.cli.core.util import CLIError
from azure.cli.command_modules.backup._client_factory import protection_containers_cf, protectable_containers_cf, \
    protection_policies_cf, backup_protection_containers_cf, backup_protectable_items_cf, \
    resources_cf, backup_protected_items_cf, protected_items_cf
from azure.cli.core.azclierror import ArgumentUsageError, ValidationError

from azure.mgmt.recoveryservicesbackup.activestamp import RecoveryServicesBackupClient
from azure.cli.core.commands.client_factory import get_mgmt_service_client

from knack.log import get_logger
logger = get_logger(__name__)

fabric_name = "Azure"
backup_management_type = "AzureStorage"
workload_type = "AzureFileShare"


def reconfigure_afs_protection(cmd, item, source_vault_name, source_vault_rg,
                               new_vault_name, new_vault_rg,
                               new_policy_name, retain_as_per_policy, tenant_id):
    """Reconfigure Azure File Share protection to a new vault and policy.

    Steps:
    1. Disable protection (retain or stop based on flag) in source vault.
    2. Unregister storage account container (if no remaining protected items) from source vault.
    3. Ensure storage account is registered / refreshed in destination vault.
    4. Enable protection for the same file share name in destination vault with new policy.
    5. Return the newly protected item from destination vault.
    """
    logger.warning("For Storage reconfigure protection, all backup items within the "
                   "container must have protection disabled first.")

    # 1. Disable in old vault (retain as per policy if requested)
    items_client = protected_items_cf(cmd.cli_ctx)
    disable_protection(cmd, items_client, source_vault_rg, source_vault_name, item,
                       retain_as_per_policy, tenant_id)

    # 2. Unregister container in old vault only if this was the last protected item for that storage account
    _maybe_unregister_storage_account(cmd, backup_protected_items_cf(cmd.cli_ctx), source_vault_rg, source_vault_name,
                                      item.properties.container_name)

    # 3. Enable protection in destination vault - also registers storage account in destination vault
    new_item = enable_for_AzureFileShare(cmd, items_client, new_vault_rg, new_vault_name, item.name,
                                         item.properties.container_name, new_policy_name)
    return new_item


def _maybe_unregister_storage_account(cmd, client, resource_group_name, vault_name, container_name):
    """Unregister the storage account container if no more protected items exist in the source vault."""
    items = common.list_items(cmd, client, resource_group_name, vault_name,
                              workload_type=workload_type, container_name=container_name,
                              container_type=backup_management_type)
    remaining = [pi for pi in items if pi.properties.protection_state.lower() == 'protected']
    if remaining:
        raise ValidationError('Cannot unregister container as other items are still protected.')

    # Attempt unregister
    try:
        containers_client = protection_containers_cf(cmd.cli_ctx)
        unregister_afs_container(cmd, containers_client, vault_name, resource_group_name, container_name)
    except Exception as ex:  # pylint: disable=broad-except
        logger.warning('Could not unregister container %s due to a failure: %s. '
                       'Continuing the operation, but if the container is still registered, it may need to be '
                       'unregistered manually for the operation to succeed.', container_name, str(ex))


def enable_for_AzureFileShare(cmd, client, resource_group_name, vault_name, afs_name,
                              storage_account_name, policy_name):
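    """Enable backup protection for an Azure File Share.

    Looks up the storage account among the vault's registered containers; if absent, searches the
    protectable (unregistered) containers, refreshing the vault's container cache if needed, and
    registers the storage account. Then resolves the file share as a protectable item (returning
    the already-protected item if one exists) and triggers protection with the given policy.
    """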

    # get registered storage accounts
    containers_client = backup_protection_containers_cf(cmd.cli_ctx)
    registered_containers = common.list_containers(containers_client, resource_group_name, vault_name, "AzureStorage")
    storage_account = _get_storage_account_from_list(registered_containers, storage_account_name)

    # get unregistered storage accounts
    if storage_account is None:
        unregistered_containers = list_protectable_containers(cmd.cli_ctx, resource_group_name, vault_name)
        storage_account = _get_storage_account_from_list(unregistered_containers, storage_account_name)

        if storage_account is None:
            # refresh containers in the vault
            protection_containers_client = protection_containers_cf(cmd.cli_ctx)
            filter_string = helper.get_filter_string({'backupManagementType': "AzureStorage"})

            refresh_result = protection_containers_client.refresh(vault_name, resource_group_name, fabric_name,
                                                                  filter=filter_string,
                                                                  cls=helper.get_pipeline_response)
            helper.track_refresh_operation(cmd.cli_ctx, refresh_result, vault_name, resource_group_name)

            # refetch the protectable containers after refresh
            unregistered_containers = list_protectable_containers(cmd.cli_ctx, resource_group_name, vault_name)
            storage_account = _get_storage_account_from_list(unregistered_containers, storage_account_name)

            if storage_account is None:
                raise CLIError("Storage account not found or not supported.")

        # register storage account
        protection_containers_client = protection_containers_cf(cmd.cli_ctx)
        properties = AzureStorageContainer(backup_management_type="AzureStorage",
                                           source_resource_id=storage_account.properties.container_id,
                                           workload_type="AzureFileShare")
        param = ProtectionContainerResource(properties=properties)
        result = protection_containers_client.begin_register(vault_name, resource_group_name, fabric_name,
                                                             storage_account.name, param, polling=False,
                                                             cls=helper.get_pipeline_response).result()
        helper.track_register_operation(cmd.cli_ctx, result, vault_name, resource_group_name, storage_account.name)

    protectable_item = _get_protectable_item_for_afs(cmd.cli_ctx, vault_name, resource_group_name, afs_name,
                                                     storage_account)

    if protectable_item is None:
        items_client = backup_protected_items_cf(cmd.cli_ctx)
        item = common.show_item(cmd, items_client, resource_group_name, vault_name, storage_account_name,
                                afs_name, "AzureStorage")
        if item is None:
            raise CLIError(
                "Could not find a file share named '{}' to protect, nor an already protected "
                "file share with that name.".format(afs_name))
        return item
    policy = common.show_policy(protection_policies_cf(cmd.cli_ctx), resource_group_name, vault_name, policy_name)
    helper.validate_policy(policy)

    helper.validate_azurefileshare_item(protectable_item)

    container_uri = helper.get_protection_container_uri_from_id(protectable_item.id)
    item_uri = helper.get_protectable_item_uri_from_id(protectable_item.id)
    item_properties = AzureFileshareProtectedItem()

    item_properties.policy_id = policy.id
    item_properties.source_resource_id = protectable_item.properties.parent_container_fabric_id
    item = ProtectedItemResource(properties=item_properties)

    result = client.create_or_update(vault_name, resource_group_name, fabric_name,
                                     container_uri, item_uri, item, cls=helper.get_pipeline_response)
    return helper.track_backup_job(cmd.cli_ctx, result, vault_name, resource_group_name)


def backup_now(cmd, client, resource_group_name, vault_name, item, retain_until):
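    """Trigger an on-demand backup of a protected file share.

    If retain_until is not given, the recovery point is retained for 30 days from now (UTC).
    """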

    if retain_until is None:
        retain_until = datetime.now(timezone.utc) + timedelta(days=30)

    container_uri = helper.get_protection_container_uri_from_id(item.id)
    item_uri = helper.get_protected_item_uri_from_id(item.id)
    trigger_backup_request = _get_backup_request(retain_until)

    result = client.trigger(vault_name, resource_group_name, fabric_name,
                            container_uri, item_uri, trigger_backup_request, cls=helper.get_pipeline_response)
    return helper.track_backup_job(cmd.cli_ctx, result, vault_name, resource_group_name)


def _get_backup_request(retain_until):
    trigger_backup_properties = AzureFileShareBackupRequest(recovery_point_expiry_time_in_utc=retain_until)
    trigger_backup_request = BackupRequestResource(properties=trigger_backup_properties)
    return trigger_backup_request


def _get_protectable_item_for_afs(cli_ctx, vault_name, resource_group_name, afs_name, storage_account):
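    """Find the protectable item for the given file share, triggering a container-level inquiry
    and retrying once if the item is not found on the first attempt."""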
    storage_account_name = storage_account.name
    protection_containers_client = protection_containers_cf(cli_ctx)

    protectable_item = _try_get_protectable_item_for_afs(cli_ctx, vault_name, resource_group_name,
                                                         afs_name, storage_account_name)

    if protectable_item is None:

        filter_string = helper.get_filter_string({'workloadType': "AzureFileShare"})
        result = protection_containers_client.inquire(vault_name, resource_group_name, fabric_name,
                                                      storage_account.name, filter=filter_string,
                                                      cls=helper.get_pipeline_response)

        helper.track_inquiry_operation(cli_ctx, result, vault_name, resource_group_name, storage_account.name)

        protectable_item = _try_get_protectable_item_for_afs(cli_ctx, vault_name, resource_group_name, afs_name,
                                                             storage_account_name)
    return protectable_item


def _try_get_protectable_item_for_afs(cli_ctx, vault_name, resource_group_name, afs_name, storage_account_name):
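    """Filter the vault's protectable items down to the given storage account and file share.

    Both names may be native names or friendly names. Returns the single match, None if there
    is no match, and raises CLIError if a friendly name is ambiguous.
    """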
    backup_protectable_items_client = backup_protectable_items_cf(cli_ctx)

    filter_string = helper.get_filter_string({
        'backupManagementType': backup_management_type,
        'workloadType': workload_type})

    protectable_items_paged = backup_protectable_items_client.list(vault_name, resource_group_name, filter_string)
    protectable_items = helper.get_list_from_paged_response(protectable_items_paged)
    result = protectable_items
    if helper.is_native_name(storage_account_name):
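        # A protectable item ID ends in .../protectionContainers/{container}/protectableItems/{item};
        # '/'-split index 12 is the container name, i.e. the storage account's native name.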
        result = [protectable_item for protectable_item in result
                  if protectable_item.id.split('/')[12] == storage_account_name.lower()]
    else:
        result = [protectable_item for protectable_item in result
                  if protectable_item.properties.parent_container_friendly_name.lower() == storage_account_name.lower()]
    if helper.is_native_name(afs_name):
        result = [protectable_item for protectable_item in result
                  if protectable_item.name.lower() == afs_name.lower()]
    else:
        result = [protectable_item for protectable_item in result
                  if protectable_item.properties.friendly_name.lower() == afs_name.lower()]
    if len(result) > 1:
        raise CLIError("Could not find a unique resource, Please pass native names instead")
    if len(result) == 1:
        return result[0]
    return None


def restore_AzureFileShare(cmd, client, resource_group_name, vault_name, rp_name, item, restore_mode,
                           resolve_conflict, restore_request_type, source_file_type=None, source_file_path=None,
                           target_storage_account_name=None, target_file_share_name=None, target_folder=None,
                           target_resource_group_name=None, tenant_id=None):
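    """Trigger a restore for a protected file share.

    Supports full-share and item-level (file/folder) restores, to either the original or an
    alternate location. When the source storage account has been deleted, falls back to the
    source resource ID recorded on the protected item.
    """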

    container_uri = helper.get_protection_container_uri_from_id(item.id)
    item_uri = helper.get_protected_item_uri_from_id(item.id)

    afs_restore_request = AzureFileShareRestoreRequest()
    target_details = None

    afs_restore_request.copy_options = resolve_conflict
    afs_restore_request.recovery_type = restore_mode

    # Try to get source resource ID from storage account first, fallback to item's source resource ID
    try:
        afs_restore_request.source_resource_id = _get_storage_account_id(cmd.cli_ctx,
                                                                         item.properties.container_name.split(';')[-1],
                                                                         item.properties.container_name.split(';')[-2])
        # Check if source_resource_id is null or empty after assignment
        if not afs_restore_request.source_resource_id:
            raise CLIError("Source resource ID is null or empty after retrieval from storage account.")
    except CLIError as e:
        logger.warning(
            "Failed to get storage account ID: %s. Falling back to source resource ID from protected item.",
            str(e))
        source_resource_id = _get_source_resource_id_from_item(item)
        if source_resource_id:
            afs_restore_request.source_resource_id = source_resource_id
        else:
            raise CLIError(
                "Unable to retrieve source resource ID. The storage account might have been deleted "
                "and no fallback source resource ID is available.") from e

    afs_restore_request.restore_request_type = restore_request_type

    restore_file_specs = None

    if source_file_path is not None:
        if len(source_file_path) > 99:
            raise ArgumentUsageError("""
            You can only recover a maximum of 99 files/folders.
            Please provide no more than 99 source file paths.
            """)
        restore_file_specs = []
        for filepath in source_file_path:
            restore_file_specs.append(RestoreFileSpecs(path=filepath, file_spec_type=source_file_type,
                                                       target_folder_path=target_folder))

    if restore_mode == "AlternateLocation":
        if target_resource_group_name is None:
            target_resource_group_name = resource_group_name
        target_sa_name, target_sa_rg = helper.get_resource_name_and_rg(
            target_resource_group_name,
            target_storage_account_name)
        target_details = TargetAFSRestoreInfo()
        target_details.name = target_file_share_name
        target_details.target_resource_id = _get_storage_account_id(cmd.cli_ctx, target_sa_name, target_sa_rg)
        afs_restore_request.target_details = target_details

    afs_restore_request.restore_file_specs = restore_file_specs

    trigger_restore_request = RestoreRequestResource(properties=afs_restore_request)

    if helper.has_resource_guard_mapping(cmd.cli_ctx, resource_group_name, vault_name, "RecoveryServicesRestore"):
        # Cross Tenant scenario
        if tenant_id is not None:
            client = get_mgmt_service_client(cmd.cli_ctx, RecoveryServicesBackupClient,
                                             aux_tenants=[tenant_id]).restores
        trigger_restore_request.properties.resource_guard_operation_requests = [
            helper.get_resource_guard_operation_request(
                cmd.cli_ctx, resource_group_name, vault_name, "RecoveryServicesRestore")]

    # Trigger restore
    result = client.begin_trigger(vault_name, resource_group_name, fabric_name, container_uri, item_uri, rp_name,
                                  trigger_restore_request, cls=helper.get_pipeline_response, polling=False).result()

    return helper.track_backup_job(cmd.cli_ctx, result, vault_name, resource_group_name)


def list_recovery_points(cmd, client, resource_group_name, vault_name, item, start_date=None, end_date=None,
                         use_secondary_region=None, is_ready_for_move=None, target_tier=None, tier=None,
                         recommended_for_archive=None):
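    """List recovery points for a protected file share, optionally filtered by date range and tier.

    Secondary-region queries, archive moves, and archive recommendations are not supported for
    AzureStorage and raise ArgumentUsageError.
    """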
    if use_secondary_region:
        raise ArgumentUsageError(
            """
            --use-secondary-region flag is not supported for --backup-management-type AzureStorage.
            Please either remove the flag or query for any other backup-management-type.
            """)

    if is_ready_for_move is not None or target_tier is not None:
        raise ArgumentUsageError("""Invalid argument has been passed. --is-ready-for-move true, --target-tier
        are not supported for --backup-management-type AzureStorage.""")

    if recommended_for_archive is not None:
        raise ArgumentUsageError("""--recommended-for-archive is supported by AzureIaasVM backup management
        type only.""")

    if cmd.name.split()[2] == 'show-log-chain':
        raise ArgumentUsageError("show-log-chain is supported by AzureWorkload backup management type only.")

    # Get container and item URIs
    container_uri = helper.get_protection_container_uri_from_id(item.id)
    item_uri = helper.get_protected_item_uri_from_id(item.id)

    query_end_date, query_start_date = helper.get_query_dates(end_date, start_date)

    filter_string = helper.get_filter_string({
        'startDate': query_start_date,
        'endDate': query_end_date})

    # Get recovery points
    recovery_points = client.list(vault_name, resource_group_name, fabric_name, container_uri, item_uri, filter_string)
    paged_recovery_points = helper.get_list_from_paged_response(recovery_points)

    if tier:
        filtered_recovery_points = []

        for rp in paged_recovery_points:
            # Prepare to collect tier types
            rp_tier_types = []

            # Safely grab additional_properties
            additional_props = getattr(rp.properties, 'additional_properties', {})
            if not isinstance(additional_props, dict):
                continue

            # Get details list
            tier_details_list = additional_props.get("recoveryPointTierDetails", [])
            if not isinstance(tier_details_list, list):
                continue

            for detail in tier_details_list:
                if not isinstance(detail, dict):
                    continue
                rp_type = detail.get("type")
                if rp_type:
                    rp_tier_types.append(rp_type)

            # Map types to a tier
            if 'InstantRP' in rp_tier_types and 'HardenedRP' in rp_tier_types:
                rp_tier = 'SnapshotAndVaultStandard'
            elif 'InstantRP' in rp_tier_types:
                rp_tier = 'Snapshot'
            elif 'HardenedRP' in rp_tier_types:
                rp_tier = 'VaultStandard'
            else:
                logger.warning(
                    "Unrecognized recovery point tier received. "
                    "If you see this message, please contact Microsoft Support. "
                    "The recognized tiers for AzureFileShare are: 'Snapshot', 'VaultStandard', or "
                    "'SnapshotAndVaultStandard'."
                )
                rp_tier = None

            # Filter by matching tier
            if rp_tier == tier:
                filtered_recovery_points.append(rp)

        return filtered_recovery_points

    return paged_recovery_points


def update_policy_for_item(cmd, client, resource_group_name, vault_name, item, policy, tenant_id=None,
                           is_critical_operation=False, yes=False):
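    """Switch a protected file share to the given backup policy; the policy's backup management
    type must match the protected item's."""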
    if item.properties.backup_management_type != policy.properties.backup_management_type:
        raise CLIError(
            """
            The policy type should match with the workload being protected.
            Use the relevant get-default policy command and use it to update the policy for the workload.
            """)

    # Get container and item URIs
    container_uri = helper.get_protection_container_uri_from_id(item.id)
    item_uri = helper.get_protected_item_uri_from_id(item.id)

    # Update policy request
    afs_item_properties = AzureFileshareProtectedItem()
    afs_item_properties.policy_id = policy.id
    afs_item_properties.source_resource_id = item.properties.source_resource_id
    afs_item = ProtectedItemResource(properties=afs_item_properties)
    # Fetch the existing policy once; it is needed for both the retention check and validation
    existing_policy_name = item.properties.policy_id.split('/')[-1]
    existing_policy = common.show_policy(protection_policies_cf(cmd.cli_ctx), resource_group_name, vault_name,
                                         existing_policy_name)

    if is_critical_operation:
        if helper.is_retention_duration_decreased(existing_policy, policy, "AzureStorage"):
            # update the payload with critical operation and add auxiliary header for cross tenant case
            if tenant_id is not None:
                client = get_mgmt_service_client(cmd.cli_ctx, RecoveryServicesBackupClient,
                                                 aux_tenants=[tenant_id]).protected_items
            afs_item.properties.resource_guard_operation_requests = [helper.get_resource_guard_operation_request(
                cmd.cli_ctx, resource_group_name, vault_name, "updateProtection")]

    # Validate existing & new policy
    helper.validate_update_policy_request(existing_policy, policy, yes)

    # Update policy
    result = client.create_or_update(vault_name, resource_group_name, fabric_name,
                                     container_uri, item_uri, afs_item, cls=helper.get_pipeline_response)
    return helper.track_backup_job(cmd.cli_ctx, result, vault_name, resource_group_name)


def disable_protection(cmd, client, resource_group_name, vault_name, item,
                       retain_recovery_points_as_per_policy=False, tenant_id=None):
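    """Disable protection for a file share: suspend backups (retaining recovery points as per
    policy) when requested, otherwise stop protection."""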
    # Get container and item URIs
    container_uri = helper.get_protection_container_uri_from_id(item.id)
    item_uri = helper.get_protected_item_uri_from_id(item.id)

    afs_item_properties = AzureFileshareProtectedItem()
    afs_item_properties.policy_id = ''
    if retain_recovery_points_as_per_policy:
        afs_item_properties.protection_state = ProtectionState.backups_suspended
    else:
        afs_item_properties.protection_state = ProtectionState.protection_stopped
    afs_item_properties.source_resource_id = item.properties.source_resource_id
    afs_item = ProtectedItemResource(properties=afs_item_properties)

    # ResourceGuard scenario: if we are stopping backup and there is MUA setup for the scenario,
    # we want to set the appropriate parameters.
    if afs_item.properties.protection_state == ProtectionState.protection_stopped:
        if helper.has_resource_guard_mapping(cmd.cli_ctx, resource_group_name,
                                             vault_name, "RecoveryServicesStopProtection"):
            # Cross Tenant scenario
            if tenant_id is not None:
                client = get_mgmt_service_client(cmd.cli_ctx, RecoveryServicesBackupClient,
                                                 aux_tenants=[tenant_id]).protected_items
            afs_item.properties.resource_guard_operation_requests = [helper.get_resource_guard_operation_request(
                cmd.cli_ctx, resource_group_name, vault_name, "RecoveryServicesStopProtection")]

    result = client.create_or_update(vault_name, resource_group_name, fabric_name,
                                     container_uri, item_uri, afs_item, cls=helper.get_pipeline_response)
    return helper.track_backup_job(cmd.cli_ctx, result, vault_name, resource_group_name)


def resume_protection(cmd, client, resource_group_name, vault_name, item, policy):
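    """Resume protection for a file share by re-associating it with the given policy."""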
    return update_policy_for_item(cmd, client, resource_group_name, vault_name, item, policy)


def _get_storage_account_id(cli_ctx, storage_account_name, storage_account_rg):
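    """Resolve the resource ID of a storage account, checking the classic (Microsoft.ClassicStorage)
    provider first and falling back to the Microsoft.Storage provider."""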
    resources_client = resources_cf(cli_ctx)
    classic_storage_resource_namespace = 'Microsoft.ClassicStorage'
    storage_resource_namespace = 'Microsoft.Storage'
    parent_resource_path = 'storageAccounts'
    resource_type = ''
    classic_api_version = '2015-12-01'
    api_version = '2016-01-01'

    storage_account = None
    try:
        storage_account = resources_client.get(storage_account_rg, classic_storage_resource_namespace,
                                               parent_resource_path, resource_type, storage_account_name,
                                               classic_api_version)
    except:  # pylint: disable=bare-except
        storage_account = resources_client.get(storage_account_rg, storage_resource_namespace, parent_resource_path,
                                               resource_type, storage_account_name, api_version)
    return storage_account.id


def _get_source_resource_id_from_item(item):
    """
    Helper function to retrieve source resource ID from a protected item.
    This is used as a fallback when the storage account is deleted.
    """
    if item and hasattr(item, 'properties') and hasattr(item.properties, 'source_resource_id'):
        return item.properties.source_resource_id
    return None


def set_policy(cmd, client, resource_group_name, vault_name, policy, policy_name, tenant_id=None,
               is_critical_operation=False, yes=False):
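    """Update an existing AzureFileShare backup policy from a JSON policy definition."""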
    if policy_name is None:
        raise CLIError(
            """
            Policy name is required for set policy.
            """)

    policy_object = helper.get_policy_from_json(client, policy)
    policy_object.properties.work_load_type = workload_type
    existing_policy = common.show_policy(client, resource_group_name, vault_name, policy_name)

    helper.validate_update_policy_request(existing_policy, policy_object, yes)
    if is_critical_operation:
        if helper.is_retention_duration_decreased(existing_policy, policy_object, "AzureStorage"):
            # update the payload with critical operation and add auxiliary header for cross tenant case
            if tenant_id is not None:
                client = get_mgmt_service_client(cmd.cli_ctx, RecoveryServicesBackupClient,
                                                 aux_tenants=[tenant_id]).protection_policies
            policy_object.properties.resource_guard_operation_requests = [helper.get_resource_guard_operation_request(
                cmd.cli_ctx, resource_group_name, vault_name, "updatePolicy")]
    return client.create_or_update(vault_name, resource_group_name, policy_name, policy_object)


def create_policy(client, resource_group_name, vault_name, name, policy):
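    """Create a new AzureFileShare backup policy from a JSON policy definition."""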
    policy_object = helper.get_policy_from_json(client, policy)
    policy_object.name = name
    policy_object.properties.backup_management_type = backup_management_type
    policy_object.properties.work_load_type = workload_type
    return client.create_or_update(vault_name, resource_group_name, name, policy_object)


def unregister_afs_container(cmd, client, vault_name, resource_group_name, container_name):
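    """Unregister a storage account container from the vault and track the resulting operation."""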
    result = client.unregister(vault_name, resource_group_name, fabric_name, container_name,
                               cls=helper.get_pipeline_response)
    return helper.track_register_operation(cmd.cli_ctx, result, vault_name, resource_group_name, container_name)


def list_protectable_containers(cli_ctx, resource_group_name, vault_name):
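    """List storage containers in the vault that can be registered for AzureStorage backup."""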
    filter_string = helper.get_filter_string({
        'backupManagementType': "AzureStorage"})

    client = protectable_containers_cf(cli_ctx)
    paged_containers = client.list(vault_name, resource_group_name, fabric_name, filter_string)
    return helper.get_list_from_paged_response(paged_containers)


def _get_storage_account_from_list(container_list, storage_account_name):
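    """Find a storage account container in the list by native or friendly name.

    Friendly names may be ambiguous; raises CLIError if more than one container matches.
    """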
    storage_account = None
    for container in container_list:
        if helper.is_native_name(storage_account_name) and container.name == storage_account_name:
            return container
        friendly_name = container.properties.friendly_name
        if not helper.is_native_name(storage_account_name) and friendly_name == storage_account_name:
            if storage_account is not None:
                raise CLIError("multiple storage accounts found. Please provide native names instead")
            storage_account = container
    return storage_account
